# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
os.getcwd()
import argparse
from pathlib import Path
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.yolov8 import download_yolov8s_model
from ultralytics.utils.files import increment_path


def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False):
    """
    Run object detection on a video using YOLOv8 and SAHI.

    Args:
        weights (str): Model weights path.
        source (str): Video file path.
        view_img (bool): Show results.
        save_img (bool): Save results.
        exist_ok (bool): Overwrite existing files.
    """

    # Check source path
    if not Path(source).exists():
        raise FileNotFoundError(f"Source path '{source}' does not exist.")

    yolov8_model_path = f"{weights}"
    download_yolov8s_model(yolov8_model_path)
    detection_model = AutoDetectionModel.from_pretrained(
        model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.6, device="cpu"
    )
    if source[-3:] == 'mp4':
        # Video setup
        videocapture = cv2.VideoCapture(source)
        frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
        fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")

        # Output setup
        save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
        save_dir.mkdir(parents=True, exist_ok=True)
        video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))

        while videocapture.isOpened():
            success, frame = videocapture.read()
            if not success:
                break

            results = get_sliced_prediction(
                frame, detection_model, slice_height=256, slice_width=256, overlap_height_ratio=0.2, overlap_width_ratio=0.2
            )
            object_prediction_list = results.object_prediction_list

            boxes_list = []
            clss_list = []
            for ind, _ in enumerate(object_prediction_list):
                boxes = (
                    object_prediction_list[ind].bbox.minx,
                    object_prediction_list[ind].bbox.miny,
                    object_prediction_list[ind].bbox.maxx,
                    object_prediction_list[ind].bbox.maxy,
                )
                clss = object_prediction_list[ind].category.name
                boxes_list.append(boxes)
                clss_list.append(clss)

            for box, cls in zip(boxes_list, clss_list):
                x1, y1, x2, y2 = box
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
                label = str(cls)
                t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
                cv2.rectangle(
                    frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1
                )
                cv2.putText(
                    frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA
                )

            if view_img:
                cv2.imshow(Path(source).stem, frame)
            if save_img:
                video_writer.write(frame)

            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
        video_writer.release()
        videocapture.release()
        cv2.destroyAllWindows()
    else:
        results = get_sliced_prediction(
            source, detection_model, slice_height=256, slice_width=256, overlap_height_ratio=0.2, overlap_width_ratio=0.2
        )
        # 保存检测图片
        results.export_visuals(export_dir="demo_data/")


        image = cv2.imread('demo_data/prediction_visual.png')

        # 检查是否成功读取图片
        if image is not None:
            # 显示图片
            cv2.imshow('PNG Image', image)

            # 等待按键输入，并关闭窗口
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        else:
            print("Failed to read PNG image.")
def parse_opt():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
    parser.add_argument("--source", type=str, default='ultralytics/assets/bus.jpg', help="video file path or Photo")
    parser.add_argument("--view-img", action="store_true", default=True, help="show results")
    parser.add_argument("--save-img", action="store_true", help="save results")
    parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
    return parser.parse_args()


def main(opt):
    """Main function."""
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
